Conversation

tkarna (Contributor) commented Nov 6, 2025

Adds a transform.xegpu.set_op_layout_attr transform op that attaches an xegpu.layout attribute to the target op.

Also adds a getLayoutAttrFromOperands utility function.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.
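For example, a minimal usage sketch adapted from the test cases in this patch (the matched op and the layout values are illustrative):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // Attaches layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]> to the matched op.
    transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
    transform.yield
  }
}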

llvmbot (Member) commented Nov 6, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Tuomas Kärnä (tkarna)

Changes

Adds a transform.xegpu.set_op_layout_attr transform op that attaches an xegpu.layout attribute to the target op.

Also adds a getLayoutAttrFromOperands utility function.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.


Patch is 23.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166854.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td (+65)
  • (modified) mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp (+106-19)
  • (modified) mlir/python/mlir/dialects/transform/xegpu.py (+47)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir (+58)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops.mlir (+134)
  • (modified) mlir/test/python/dialects/transform_xegpu_ext.py (+49)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..4e0eae1007c8f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 }
 
+def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Set xegpu.layout attribute of an op.";
+  let description = [{
+    Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
+    `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
+    target operand/result value is defined by the `index` argument. The layout
+    is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface : $target,
+                   DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+                   DefaultValuedAttr<UnitAttr, "false">:$result
+                   );
+
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "int64_t":$index,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData,
+                   CArg<"bool", "false">:$result
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target (`result` $result^)? (`index` `=` $index^)?
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..456cfb9ddd2bc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
       /*order=*/nullptr);
 }
 
+/// Generate `xegpu::LayoutAttr` from op mixed layout values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
+                          transform::TransformState &state,
+                          TransformOpInterface transformOp,
+                          ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+                          ArrayRef<::mlir::OpFoldResult> mixedSgData,
+                          ArrayRef<::mlir::OpFoldResult> mixedInstData,
+                          xegpu::LayoutAttr &layoutAttr) {
+  SmallVector<int32_t> sgLayout, sgData, instData;
+  auto status =
+      convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+  if (!status.succeeded())
+    return status;
+
+  status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+  if (!status.succeeded())
+    return status;
+
+  status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+  if (!status.succeeded())
+    return status;
+  auto maybeInstData = instData.empty()
+                           ? std::nullopt
+                           : std::optional<ArrayRef<int32_t>>(instData);
+
+  layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Replace xegpu.create_nd_desc op with a new one with the given layout.
 static xegpu::CreateNdDescOp
 setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
   Operation *target = *targetOps.begin();
 
-  SmallVector<int32_t> sgLayout;
-  DiagnosedSilenceableFailure status =
-      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
   if (!status.succeeded())
     return status;
 
-  SmallVector<int32_t> sgData;
-  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
-  if (!status.succeeded())
-    return status;
-
-  SmallVector<int32_t> instData;
-  status =
-      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
-  if (!status.succeeded())
-    return status;
-  auto maybeInstData = instData.empty()
-                           ? std::nullopt
-                           : std::optional<ArrayRef<int32_t>>(instData);
-
   // For now only create_nd_desc op is supported.
   auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
   if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Set layout attr in desc op's return type. Replaces old desc op.
-  auto layoutAttr =
-      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
   auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
 
   // Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::SetOpLayoutAttrOp::build(
+    OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+    ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+    ArrayRef<OpFoldResult> mixedInstData, bool result) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*index=*/index,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData,
+        /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  bool resultTarget = getResult();
+
+  int64_t index = getIndex();
+  if (resultTarget && index >= target->getNumResults()) {
+    return emitSilenceableFailure(getLoc())
+           << "Index exceeds the number of op results";
+  }
+  if (!resultTarget && index >= target->getNumOperands()) {
+    return emitSilenceableFailure(getLoc())
+           << "Index exceeds the number of op operands";
+  }
+
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
+  if (!status.succeeded())
+    return status;
+
+  // Set layout attribute for the op result or operand
+  if (resultTarget) {
+    xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+  } else {
+    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..46a1f032630d1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,50 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
+    """Specialization for SetOpLayoutAttrOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        *,
+        inst_data: MixedValues = None,
+        index: Union[int, Attribute] = None,
+        result: Union[bool, Attribute] = None,
+        loc=None,
+        ip=None,
+    ):
+        inst_data = [] if inst_data is None else inst_data
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+        super().__init__(
+            _get_op_result_or_value(target),
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            index=index,
+            result=result,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..726b6748452ae 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_result_index
+func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Index exceeds the number of op results}}
+    transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
+func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Index exceeds the number of op operands}}
+    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_multiple
+func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+    transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..089a8fb4fd9b6 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,137 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_default_index
+func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[0, 0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param
+func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
+func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result0
+func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand_minimal
+func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
+  %2 = arith.extf %1 : vector<256x32...
[truncated]

tkarna (Contributor, Author) commented Nov 6, 2025

rolfmorel (Contributor) left a comment:

Generally seems good to me.

Some nits and a question about whether a helper function needs the interface it has. Otherwise, looks fine to me to go in.

Jianhui-Li (Contributor) left a comment:

LGTM. A warning that a near-future refactoring may impact this operation's interface.

ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
ArrayRef<::mlir::OpFoldResult> mixedSgData,
ArrayRef<::mlir::OpFoldResult> mixedInstData,
xegpu::LayoutAttr &layoutAttr) {
A contributor commented:

Missing order attribute. It defines the layout of subgroup ids and lane ids. It is optional; if not set, the layout of these ids is row-major. It can be added later, when you need to handle vector.transpose.
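For illustration, a layout with an explicit order might look something like the following (hypothetical values; order handling is not part of this patch), where the order array would specify the dimension ordering of the subgroup/lane ids:

#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [0, 1]>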

tkarna (Contributor, Author) replied:

Thanks. I prefer to add features when they are needed. It's also easy to add relevant test cases then.

let summary = "Set xegpu.layout attribute of an op.";
let description = [{
Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
`layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
A contributor commented:

Just a warning here: XeGPU is refactoring the layout-setting API so that the user only needs to set the anchor op's layout, and XeGPU will take care of the propagation automatically. The new API that sets anchor ops' layouts doesn't require "layout_result_" and "layout_operand_", so it may impact this operation as well.

@adam-smnk adam-smnk merged commit 94a7006 into llvm:main Nov 10, 2025
10 checks passed
@tkarna tkarna deleted the xegpu-tr-ops-set-op-layout-attr branch November 10, 2025 15:12
llvm-ci (Collaborator) commented Nov 10, 2025

LLVM Buildbot has detected a new failure on builder ppc64le-mlir-rhel-clang running on ppc64le-mlir-rhel-test while building mlir at step 6 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/129/builds/32782

Here is the relevant piece of the build log for reference:
Step 6 (test-build-check-mlir-build-only-check-mlir) failure: 1200 seconds without output running [b'ninja', b'check-mlir'], attempting to kill
...
PASS: MLIR :: mlir-tblgen/cpp-class-comments.td (3631 of 3642)
PASS: MLIR :: Pass/pipeline-parsing.mlir (3632 of 3642)
PASS: MLIR :: mlir-reduce/dce-test.mlir (3633 of 3642)
PASS: MLIR-Unit :: IR/./MLIRIRTests/0/130 (3634 of 3642)
PASS: MLIR-Unit :: Interfaces/./MLIRInterfacesTests/13/22 (3635 of 3642)
PASS: MLIR-Unit :: Interfaces/./MLIRInterfacesTests/12/22 (3636 of 3642)
PASS: MLIR-Unit :: Pass/./MLIRPassTests/10/13 (3637 of 3642)
PASS: MLIR-Unit :: Interfaces/./MLIRInterfacesTests/11/22 (3638 of 3642)
PASS: MLIR-Unit :: IR/./MLIRIRTests/38/130 (3639 of 3642)
PASS: MLIR-Unit :: IR/./MLIRIRTests/39/130 (3640 of 3642)
command timed out: 1200 seconds without output running [b'ninja', b'check-mlir'], attempting to kill
process killed by signal 9
program finished with exit code -1
elapsedTime=2414.160779
